#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
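
"""Emitters that build MLIR structured ops from Linalg OpDSL configurations."""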

from typing import Callable, Dict, List, Sequence, Tuple, Union

from .....ir import *

from .... import func
from .... import linalg
from .... import math
from .... import arith
from .... import complex
from ...._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)

from .scalar_expr import *
from .config import *
from .comprehension import *
import numpy as np

__all__ = [
    "emit_generic_structured_op",
    "emit_named_structured_op",
    "ValueList",
]

# Type aliases.
ValueList = Union[Sequence[Value], OpResultList]


def isa(cls: Type, ty: Type) -> bool:
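    """Returns whether `ty` can be cast to the MLIR type subclass `cls`.

    Relies on the bindings raising ValueError when the cast fails.
    """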
    try:
        cls(ty)
        return True
    except ValueError:
        return False


def prepare_common_structured_op(
    op_config: LinalgStructuredOpConfig,
    *ins: Value,
    outs: ValueList,
    **attrs: Union[Sequence[int], UnaryFnType, BinaryFnType, TypeFnType],
):
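    """Validates the arguments and computes state shared by both emitters.

    Partitions the operand definitions by kind, checks the arity of `ins`
    and `outs`, substitutes index attribute symbols into the indexing maps,
    and derives the type mapping, attribute dictionaries, and block argument
    types consumed by the generic and the named emission paths.
    """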
    all_arg_defs = op_config.ordered_operands
    in_arg_defs = [
        d
        for d in all_arg_defs
        if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
    ]
    out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR]
    index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR]
    fn_attr_arg_defs = [
        d
        for d in all_arg_defs
        if d.kind
        in [
            OperandKind.UNARY_FN_ATTR,
            OperandKind.BINARY_FN_ATTR,
            OperandKind.TYPE_FN_ATTR,
        ]
    ]

    # Verify outs is a sequence or a list of results.
    if not isinstance(outs, (Sequence, OpResultList)):
        raise ValueError(
            f"Expected named argument outs to have type Sequence or "
            f"OpResultLis but got {type(outs)}"
        )

    # Arity validation.
    if len(ins) != len(in_arg_defs):
        raise ValueError(
            f"Expected {len(in_arg_defs)} inputs but got " f"{len(ins)} for {op_config}"
        )
    if outs and len(outs) != len(out_arg_defs):
        raise ValueError(
            f"Expected {len(out_arg_defs)} outputs but got "
            f"{len(outs)} for {op_config}"
        )

    # Compute a replacement list for all index attribute symbols.
    expressions = []  # type: Sequence[AffineExpr]
    replacements = []  # type: Sequence[AffineExpr]
    for index_attr in index_attr_arg_defs:
        index_attr_vals = index_attr.operand_def.default_indices
        if index_attr.name in attrs:
            index_attr_vals = attrs.get(index_attr.name)
        assert index_attr_vals, "Index attribute has no value"
        if not all(isinstance(value, int) for value in index_attr_vals):
            raise ValueError(
                f"Attribute {index_attr.name} needs to be of type "
                f"Sequence[int] but got {type(index_attr_vals)}"
            )
        results = index_attr.index_attr_map.results  # type: AffineExprList
        if len(index_attr_vals) != len(results):
            raise ValueError(
                f"Attribute {index_attr.name} expects {len(results)} "
                f"values but got {len(index_attr_vals)}"
            )
        for expr, value in zip(results, index_attr_vals):
            expressions.append(expr)
            replacements.append(AffineConstantExpr.get(value))

    # Replace all index attribute symbols by their value.
    # TODO: Add support for shape symbols.
    indexing_maps = []  # type: Sequence[AffineMap]
    for curr in op_config.indexing_maps:
        for expression, replacement in zip(expressions, replacements):
            curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
        indexing_maps.append(curr)

    # TODO: Linalg verification does not currently allow symbols.
    # Compress them for now and verify none are left.
    indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current)
    if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
        raise ValueError(
            f"Expected indexing_maps to use no symbols after "
            f"replacement and compression but got {indexing_maps}"
        )

    outs, out_types = _infer_structured_outs(
        op_config, in_arg_defs, ins, out_arg_defs, outs
    )

    result_types = [t for t in out_types if isa(RankedTensorType, t)]

    # Initialize the type dictionary with the predefined types.
    type_mapping = dict()  # type: Dict[str, Type]
    type_mapping["F32"] = F32Type.get()
    type_mapping["F64"] = F64Type.get()
    type_mapping["I32"] = IntegerType.get_signless(32)
    type_mapping["I64"] = IntegerType.get_signless(64)

    # Extract type vars for input/output based types.
    block_arg_types = list()  # type: List[Type]
    for arg_def, arg_element_type in zip(
        in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs)
    ):
        _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)

    # Emit the generic op.
    # TODO: Support emission of pure memref form.
    indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps])
    iterator_types_attr = ArrayAttr.get(
        [
            Attribute.parse(f"#linalg.iterator_type<{s}>")
            for s in op_config.iterator_types
        ]
    )

    # Compute the index attributes used when emitting a named structured op.
    index_attrs = {}  # type: Dict[str, DenseElementsAttr]
    for index_attr in index_attr_arg_defs:
        index_attr_vals = attrs.get(index_attr.name)
        # Only forward attributes set to a non-default value.
        if index_attr_vals:
            array = np.array(index_attr_vals, dtype=np.int64)
            index_attrs[index_attr.name] = DenseElementsAttr.get(array)

    # Compute the function attribute mapping.
    fn_attr_mapping = {}
    for fn_attr in fn_attr_arg_defs:
        attr_val = fn_attr.operand_def.default_fn
        attr_kind = fn_attr.kind
        if fn_attr.name in attrs:
            fn = attrs.get(fn_attr.name)
            if attr_kind == OperandKind.UNARY_FN_ATTR:
                if not isinstance(fn, UnaryFnType):
                    raise ValueError(
                        f"Attribute {fn_attr.name} needs to be of type "
                        f"UnaryFnType but got {type(attr_val)}"
                    )
            elif attr_kind == OperandKind.BINARY_FN_ATTR:
                if not isinstance(fn, BinaryFnType):
                    raise ValueError(
                        f"Attribute {fn_attr.name} needs to be of type "
                        f"BinaryFnType but got {type(attr_val)}"
                    )
            else:
                if not isinstance(fn, TypeFnType):
                    raise ValueError(
                        f"Attribute {fn_attr.name} needs to be of type "
                        f"TypeFnType but got {type(attr_val)}"
                    )
            attr_val = fn.fn_name
        assert attr_val, "Function attribute has no value"
        fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)

    return (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    )


def emit_generic_structured_op(
    op_config: LinalgStructuredOpConfig,
    *ins: Value,
    outs: ValueList,
    **attrs: Union[Sequence[int], UnaryFnType, BinaryFnType, TypeFnType],
):
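    """Emits the op configuration as a `linalg.generic` operation.

    Rank-polymorphic configurations, i.e. those with a rank-zero iteration
    domain, are expanded to match the rank of the first output tensor. The
    body region is built by evaluating the configured assignments.
    """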
    (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)

    # An operation that accesses only scalars and scalar/rank zero tensors is
    # rank polymorphic; its iteration domain has rank zero. We implement rank
    # polymorphism by generating indexing maps and iterator types that match
    # the rank of the first output tensor.
    if not iterator_types_attr:
        rank = ShapedType(outs[0].type).rank
        iterator_types_attr = ArrayAttr.get(
            [Attribute.parse("#linalg.iterator_type<parallel>")] * rank
        )
        scalar_map = AffineMap.get(rank, 0, [])
        tensor_map = AffineMap.get_identity(rank)
        indexing_maps = []
        for arg_def in all_arg_defs:
            if arg_def.operand_def.kind == OperandKind.SCALAR:
                indexing_maps.append(scalar_map)
            if arg_def.operand_def.is_tensor():
                idx = arg_def.operand_def.registered_index
                if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
                    indexing_maps.append(scalar_map)
                else:
                    indexing_maps.append(tensor_map)
        indexing_maps_attr = ArrayAttr.get(
            [AffineMapAttr.get(am) for am in indexing_maps]
        )

    generic_op = linalg.GenericOp(
        result_tensors=result_types,
        inputs=ins,
        outputs=outs,
        indexing_maps=indexing_maps_attr,
        iterator_types=iterator_types_attr,
        doc=None,  # TODO: Make optional.
        library_call=None,  # TODO: Make optional.
    )

    # Construct the body.
    block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
    block = generic_op.regions[0].blocks.append(*block_arg_types)
    block_arg_mapping = dict(zip(block_arg_names, block.arguments))
    with InsertionPoint(block):
        body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping)
        for assignment in op_config.assignments:
            body_builder.assign(assignment)
        body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))

    if len(result_types) == 1:
        return generic_op.result
    else:
        return generic_op.results


def emit_named_structured_op(
    op_config: LinalgStructuredOpConfig,
    op_name: str,
    op_class_name: str,
    *ins: Value,
    outs: ValueList,
    **attrs: Union[Sequence[int], UnaryFnType, BinaryFnType, TypeFnType],
):
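    """Emits the op configuration as the named operation `linalg.<op_name>`.

    Requires the operation to be registered and `op_class_name` to exist in
    the `linalg` dialect module; index and function attributes are attached
    and the body region is filled with the builtin implementation.
    """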
    (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)

    # If we get here, there must exist a builtin class `op_class_name`.
    ctx = Context.current
    fully_qualified_name = "linalg." + op_name
    if (
        not ctx.is_registered_operation(fully_qualified_name)
        or op_class_name not in linalg.__dict__
    ):
        raise NotImplementedError(
            f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}"
        )

    # Set the index attributes used to compute the indexing maps.
    named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
    for name, value in index_attrs.items():
        named_op.operation.attributes[name] = value

    # Compute the function attributes by combining operand kind and function name.
    for name, (fn_name, kind) in fn_attr_mapping.items():
        assert kind.name.lower().endswith("_attr")
        enum_name = kind.name.lower()[:-5]
        named_op.operation.attributes[name] = Attribute.parse(
            f"#linalg.{enum_name}<{fn_name}>"
        )

    linalg.fill_builtin_region(named_op.operation)

    if len(result_types) == 1:
        return named_op.result
    else:
        return named_op.results


class _BodyBuilder:
    """Constructs a structured op body by evaluating assignments."""

    def __init__(
        self,
        type_mapping: Dict[str, Type],
        block_arg_mapping: Dict[str, Value],
        fn_attr_mapping: Dict[str, str],
    ):
        self.type_mapping = type_mapping
        self.block_arg_mapping = block_arg_mapping
        self.fn_attr_mapping = fn_attr_mapping
        self.yield_mapping = dict()  # type: Dict[str, Value]

    def assign(self, assignment: ScalarAssign):
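        """Evaluates an assignment and records the value to yield for its arg."""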
        if assignment.arg in self.yield_mapping:
            raise ValueError(
                f"Multiple assignments to the same argument are forbidden: "
                f"{assignment}"
            )
        self.yield_mapping[assignment.arg] = self.expression(assignment.value)

    def expression(self, expr: ScalarExpression) -> Value:
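        """Recursively emits IR for a scalar expression and returns its value."""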
        if expr.scalar_arg:
            try:
                return self.block_arg_mapping[expr.scalar_arg.arg]
            except KeyError:
                raise ValueError(
                    f"Argument {expr.scalar_arg.arg} is not bound for "
                    f"this structured op."
                )
        elif expr.scalar_const:
            value_attr = Attribute.parse(expr.scalar_const.value)
            return arith.ConstantOp(value_attr.type, value_attr).result
        elif expr.scalar_index:
            dim_attr = IntegerAttr.get(
                IntegerType.get_signless(64), expr.scalar_index.dim
            )
            return linalg.IndexOp(dim_attr).result
        elif expr.scalar_fn:
            kind = expr.scalar_fn.kind.name.lower()
            fn_name = expr.scalar_fn.fn_name
            if expr.scalar_fn.attr_name:
                fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
            fn = self._get_function(f"_{kind}_{fn_name}")
            operand_values = [
                self.expression(operand) for operand in expr.scalar_fn.operands
            ]
            if expr.scalar_fn.kind == FunctionKind.TYPE:
                operand_values = [expr.scalar_fn.type_var.name] + operand_values
            return fn(*operand_values)
        raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")

    def yield_outputs(self, *output_names: str):
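        """Emits the terminating `linalg.yield` for the assigned outputs."""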
        output_values = []
        for n in output_names:
            try:
                output_values.append(self.yield_mapping[n])
            except KeyError:
                raise ValueError(
                    f"Body assignments do not assign all outputs: " f"missing '{n}'"
                )
        linalg.YieldOp(output_values)

    def _get_function(self, fn_name: str) -> Callable:
        try:
            fn = getattr(self, fn_name)
        except AttributeError:
            raise ValueError(f"Function '{fn_name}' is not a known function")
        return fn

    def _cast(
        self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False
    ) -> Value:
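        """Casts `operand` to the type bound to `type_var_name`, if needed."""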
        try:
            to_type = self.type_mapping[type_var_name]
        except KeyError:
            raise ValueError(
                f"Unbound type variable '{type_var_name}' ("
                f"expected one of {self.type_mapping.keys()}"
            )
        if operand.type == to_type:
            return operand
        if _is_integer_type(to_type):
            return self._cast_to_integer(to_type, operand, is_unsigned_cast)
        elif _is_floating_point_type(to_type):
            return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
        raise ValueError(
            f"Unable to cast body expression from {operand.type} to {to_type}"
        )

    def _cast_to_integer(
        self, to_type: Type, operand: Value, is_unsigned_cast: bool
    ) -> Value:
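        """Casts a float, index, or integer operand to the integer `to_type`."""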
        to_width = IntegerType(to_type).width
        operand_type = operand.type
        if _is_floating_point_type(operand_type):
            if is_unsigned_cast:
                return arith.FPToUIOp(to_type, operand).result
            return arith.FPToSIOp(to_type, operand).result
        if _is_index_type(operand_type):
            return arith.IndexCastOp(to_type, operand).result
        # Assume integer.
        from_width = IntegerType(operand_type).width
        if to_width > from_width:
            if is_unsigned_cast:
                return arith.ExtUIOp(to_type, operand).result
            return arith.ExtSIOp(to_type, operand).result
        elif to_width < from_width:
            return arith.TruncIOp(to_type, operand).result
        raise ValueError(
            f"Unable to cast body expression from {operand_type} to " f"{to_type}"
        )

    def _cast_to_floating_point(
        self, to_type: Type, operand: Value, is_unsigned_cast: bool
    ) -> Value:
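        """Casts an integer or floating point operand to the float `to_type`."""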
        operand_type = operand.type
        if _is_integer_type(operand_type):
            if is_unsigned_cast:
                return arith.UIToFPOp(to_type, operand).result
            return arith.SIToFPOp(to_type, operand).result
        # Assume FloatType.
        to_width = _get_floating_point_width(to_type)
        from_width = _get_floating_point_width(operand_type)
        if to_width > from_width:
            return arith.ExtFOp(to_type, operand).result
        elif to_width < from_width:
            return arith.TruncFOp(to_type, operand).result
        raise ValueError(
            f"Unable to cast body expression from {operand_type} to " f"{to_type}"
        )

    def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
        return self._cast(type_var_name, operand, False)

    def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
        return self._cast(type_var_name, operand, True)

    def _unary_exp(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.ExpOp(x).result
        raise NotImplementedError("Unsupported 'exp' operand: {x}")

    def _unary_log(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.LogOp(x).result
        raise NotImplementedError("Unsupported 'log' operand: {x}")

    def _unary_abs(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.AbsFOp(x).result
        raise NotImplementedError("Unsupported 'abs' operand: {x}")

    def _unary_ceil(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.CeilOp(x).result
        raise NotImplementedError("Unsupported 'ceil' operand: {x}")

    def _unary_floor(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return math.FloorOp(x).result
        raise NotImplementedError("Unsupported 'floor' operand: {x}")

    def _unary_negf(self, x: Value) -> Value:
        if _is_floating_point_type(x.type):
            return arith.NegFOp(x).result
        if _is_complex_type(x.type):
            return complex.NegOp(x).result
        raise NotImplementedError("Unsupported 'negf' operand: {x}")

    def _binary_add(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.AddFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.AddIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.AddOp(lhs, rhs).result
        raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")

    def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.SubFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.SubIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.SubOp(lhs, rhs).result
        raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")

    def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MulFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MulIOp(lhs, rhs).result
        if _is_complex_type(lhs.type):
            return complex.MulOp(lhs, rhs).result
        raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")

    def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MaximumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MaxSIOp(lhs, rhs).result
        raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")

    def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MaximumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MaxUIOp(lhs, rhs).result
        raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")

    def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MinimumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MinSIOp(lhs, rhs).result
        raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")

    def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
        if _is_floating_point_type(lhs.type):
            return arith.MinimumFOp(lhs, rhs).result
        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
            return arith.MinUIOp(lhs, rhs).result
        raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")


def _infer_structured_outs(
    op_config: LinalgStructuredOpConfig,
    in_arg_defs: Sequence[OperandDefConfig],
    ins: Sequence[Value],
    out_arg_defs: Sequence[OperandDefConfig],
    outs: Union[Sequence[Value], OpResultList],
) -> Tuple[ValueList, List[Type]]:
    """Infers implicit outs and output types.

    Respects existing contents of outs if not empty.

    Returns:
      normalized outs, output types
    """
    # If outs were explicitly provided, we accept them verbatim.
    if outs:
        return outs, [out.type for out in outs]

    raise NotImplementedError(
        f"Output tensor inference not yet supported for " "structured ops"
    )


def _get_types_from_values(*values: Value) -> Sequence[Type]:
    return [v.type for v in values]


def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]:
    return [odc.operand_def.name for odc in operand_configs]


def _add_type_mapping(
    operand_config: OperandDefConfig,
    operand_type: Type,
    type_mapping: Dict[str, Type],
    block_arg_types: Sequence[Type],
):
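    """Binds the operand's type variable to its element or scalar type.

    Registers the same type as the block argument type and raises when the
    variable is already bound to a different type.
    """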
    element_or_self_type = operand_type
    # Get the element type for tensor operands and the type itself for scalars.
    if operand_config.shape_map:
        try:
            element_or_self_type = ShapedType(operand_type).element_type
        except Exception as e:
            raise ValueError(f"Expected ShapedType but got {operand_type}") from e
    name = operand_config.type_var.name
    if name in type_mapping:
        if type_mapping[name] != element_or_self_type:
            raise ValueError(
                f"Cannot overwrite type mapping {name} = "
                f"{type_mapping[name]} by type {element_or_self_type}"
            )
    type_mapping[name] = element_or_self_type
    block_arg_types.append(element_or_self_type)


def _is_complex_type(t: Type) -> bool:
    return ComplexType.isinstance(t)


def _is_floating_point_type(t: Type) -> bool:
    # TODO: Create a FloatType in the Python API and implement the switch
    # there.
    return (
        F64Type.isinstance(t)
        or F32Type.isinstance(t)
        or F16Type.isinstance(t)
        or BF16Type.isinstance(t)
    )


def _is_integer_type(t: Type) -> bool:
    return IntegerType.isinstance(t)


def _is_index_type(t: Type) -> bool:
    return IndexType.isinstance(t)


def _get_floating_point_width(t: Type) -> int:
    # TODO: Create a FloatType in the Python API and implement the switch
    # there.
    if F64Type.isinstance(t):
        return 64
    if F32Type.isinstance(t):
        return 32
    if F16Type.isinstance(t):
        return 16
    if BF16Type.isinstance(t):
        return 16
    raise NotImplementedError(f"Unhandled floating point type switch {t}")
